/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "RClass.h"

#include <limits>
#include <string>

#include "CFGMutation.h"
#include "ConstantPropagationAnalysis.h"
#include "ControlFlow.h"
#include "DexAnnotation.h"
#include "DexInstruction.h"
#include "DexUtil.h"
#include "EditableCfgAdapter.h"
#include "GlobalConfig.h"
#include "InitClassesWithSideEffects.h"
#include "InitDeps.h"
#include "LiveRange.h"
#include "LocalPointersAnalysis.h"
#include "MethodOverrideGraph.h"
#include "ObjectSensitiveDce.h"
#include "RedexResources.h"
#include "Resolver.h"
#include "Show.h"
#include "SideEffectSummary.h"
#include "Timer.h"
#include "Trace.h"
#include "Walkers.h"

namespace {
// Name fragments identifying the inner classes of an aapt2-generated R class,
// one per resource type. The set of types follows aapt2's resource table:
// https://cs.android.com/android/platform/superproject/main/+/main:frameworks/base/tools/aapt2/Resource.cpp;l=107
//
// "attr-private" and "macro" are deliberately not listed; see
// https://cs.android.com/android/_/android/platform/frameworks/base/+/326e35ffaf0ee1e3d07c977217f4e600088fd9d5
// https://cs.android.com/android/platform/superproject/main/+/main:frameworks/base/tools/aapt2/java/JavaClassGenerator.cpp;l=619
static const UnorderedSet<std::string> r_inner_classes{
    "/R$anim",         "/R$animator",  "/R$array",      "/R$attr",
    "/R$bool",         "/R$color",     "/R$dimen",      "/R$drawable",
    "/R$font",         "/R$fraction",  "/R$id",         "/R$integer",
    "/R$interpolator", "/R$layout",    "/R$menu",       "/R$mipmap",
    "/R$navigation",   "/R$plurals",   "/R$raw",        "/R$string",
    "/R$style",        "/R$styleable", "/R$transition", "/R$xml"};

// Heuristic test for an autogenerated R inner class: true when `cls_name`
// contains any fragment from r_inner_classes (e.g. "/R$string"). Plain
// substring matching is intentionally looser than needed so that test
// fixtures are recognized too.
bool is_resource_class_name(const std::string_view& cls_name) {
  bool matched = false;
  for (const auto& fragment : UnorderedIterable(r_inner_classes)) {
    if (cls_name.find(fragment) == std::string_view::npos) {
      continue;
    }
    matched = true;
    break;
  }
  return matched;
}

// True when either the current name or the deobfuscated name of `cls`
// contains "/R$styleable". Styleable arrays get special treatment when being
// rewritten, because other code refers to their entries by position.
bool is_styleable(const DexClass* cls) {
  constexpr std::string_view kStyleableMarker = "/R$styleable";
  const auto current_name = cls->get_name()->str();
  const auto original_name = cls->get_deobfuscated_name_or_empty();
  if (current_name.find(kStyleableMarker) != std::string::npos) {
    return true;
  }
  return original_name.find(kStyleableMarker) != std::string::npos;
}

// True when `ref` names an `<init>` method that takes no arguments and
// returns void.
bool is_zero_arg_constructor(const DexMethodRef* ref) {
  if (!method::is_init(ref)) {
    return false;
  }
  auto* signature = ref->get_proto();
  return signature->is_void() && signature->get_args()->empty();
}

// Instrumented builds may add extra static calls to R class <clinit>s and
// constructors for tracking. Those calls break the canonical shape, and also
// what is handed to OSDCE, so they are only tolerated when R class rewriting
// cleanup is disabled. In that mode any void-returning invoke-static is
// accepted. With cleanup enabled nothing is tolerated.
bool is_tolerable_instrumentation_invoke(
    const IRInstruction* insn, const ResourceConfig& global_resources_config) {
  // The void invoke-static test is a deliberately coarse stand-in for
  // InstrumentPass's onMethodBegin call.
  return !global_resources_config.cleanup_r_class_rewriting &&
         opcode::is_invoke_static(insn->opcode()) &&
         insn->get_method()->get_proto()->is_void();
}

// An R inner class <clinit> is expected only to initialize its own static
// final fields. Reading other classes' fields is acceptable. Any instruction
// that references a method (e.g. an invoke) makes it invalid, unless it is a
// tolerated instrumentation call. A method without code is trivially valid.
bool valid_r_class_clinit(const DexMethod* clinit,
                          const ResourceConfig& global_resources_config) {
  const auto* const code = clinit->get_code();
  if (code == nullptr) {
    return true;
  }
  bool saw_forbidden_call = false;
  cfg_adapter::iterate(code, [&](const MethodItemEntry& entry) {
    const bool forbidden = entry.type == MFLOW_OPCODE &&
                           entry.insn->has_method() &&
                           !is_tolerable_instrumentation_invoke(
                               entry.insn, global_resources_config);
    if (!forbidden) {
      return cfg_adapter::LoopExit::LOOP_CONTINUE;
    }
    saw_forbidden_call = true;
    return cfg_adapter::LoopExit::LOOP_BREAK;
  });
  return !saw_forbidden_call;
}

// Should be an empty constructor that simply loads the "this" obj, calls the
// super Object constructor and returns.
bool valid_r_class_init(const DexMethod* init,
                        const ResourceConfig& global_resources_config) {
  if (!is_zero_arg_constructor(init)) {
    return false;
  }
  const auto* const code = init->get_code();
  if (code == nullptr) {
    return true;
  }
  bool is_valid{true};
  cfg_adapter::iterate(code, [&](const MethodItemEntry& mie) {
    if (mie.type == MFLOW_OPCODE) {
      auto op = mie.insn->opcode();
      if (mie.insn->has_method()) {
        if (!is_zero_arg_constructor(mie.insn->get_method()) &&
            !is_tolerable_instrumentation_invoke(mie.insn,
                                                 global_resources_config)) {
          is_valid = false;
          return cfg_adapter::LoopExit::LOOP_BREAK;
        }
      } else if (!opcode::is_a_load_param(op) && !opcode::is_a_return(op) &&
                 !opcode::is_a_const(op)) {
        is_valid = false;
        return cfg_adapter::LoopExit::LOOP_BREAK;
      }
    }
    return cfg_adapter::LoopExit::LOOP_CONTINUE;
  });
  return is_valid;
}

// Background on customized R classes:
// https://github.com/facebook/buck/blob/main/src/com/facebook/buck/android/MergeAndroidResourcesStep.java#L385
// https://github.com/facebook/buck/commit/ec583c559239256ba0478d4bfdfc8d2c21426c4b
//
// Returns true iff `cls_name` exactly equals one of the class names listed in
// the config's `customized_r_classes`.
bool is_customized_resource_class_name(
    const std::string_view& cls_name,
    const ResourceConfig& global_resources_config) {
  const auto& configured_names = global_resources_config.customized_r_classes;
  for (const auto& candidate : UnorderedIterable(configured_names)) {
    if (cls_name == candidate) {
      return true;
    }
  }
  return false;
}

// A customized resource class is an R class that was not produced by the
// Android toolchain itself (e.g. one produced by buck). Such classes must be
// named in the Redex config. Either the current or the deobfuscated class
// name may match.
bool is_customized_resource_class(
    const DexClass* cls, const ResourceConfig& global_resources_config) {
  if (is_customized_resource_class_name(cls->get_name()->str(),
                                        global_resources_config)) {
    return true;
  }
  return is_customized_resource_class_name(
      cls->get_deobfuscated_name_or_empty(), global_resources_config);
}

// True when the class declaring `field_ref` can be resolved and is marked
// external. A declaring class that cannot be resolved yields false.
bool is_external_ref(const DexFieldRef* field_ref) {
  auto* owner = type_class(field_ref->get_class());
  return owner != nullptr && owner->is_external();
}
} // namespace

namespace resources {
// A "non-customized" R class is one prepared by the Android toolchain. It is
// recognized by its current or deobfuscated name containing a known R inner
// class fragment.
bool is_non_customized_r_class(const DexClass* cls) {
  if (is_resource_class_name(cls->get_name()->str())) {
    return true;
  }
  return is_resource_class_name(cls->get_deobfuscated_name_or_empty());
}

// Marks the <clinit> of every R class in `stores` as not eligible for
// outlining. Presumably this keeps those <clinit>s in the simple shape that
// the analyses in this file expect (no invokes). <clinit>s that are already
// marked are left alone.
void prepare_r_classes(DexStoresVector& stores,
                       const GlobalConfig& global_config) {
  auto classes = build_class_scope(stores);
  RClassReader reader(global_config);
  for (auto* cls : classes) {
    if (!reader.is_r_class(cls)) {
      continue;
    }
    auto* clinit = cls->get_clinit();
    if (clinit == nullptr || clinit->rstate.should_not_outline()) {
      continue;
    }
    TRACE(OPTRES, 1, "Disabling outlining for %s", SHOW(clinit));
    clinit->rstate.set_no_outlining();
  }
}

// An R inner class must have no virtual methods and at most two direct
// methods:
//  * a <clinit>, which must satisfy valid_r_class_clinit;
//  * a constructor, which must satisfy valid_r_class_init.
// Any other direct method makes the structure invalid.
bool RClassReader::valid_r_class_structure(const DexClass* cls) const {
  const auto& direct_methods = cls->get_dmethods();
  if (!cls->get_vmethods().empty() || direct_methods.size() > 2) {
    return false;
  }
  auto* clinit = cls->get_clinit();
  for (auto* method : direct_methods) {
    bool acceptable = false;
    if (method == clinit) {
      acceptable = valid_r_class_clinit(method, m_global_resources_config);
    } else if (is_constructor(method)) {
      acceptable = valid_r_class_init(method, m_global_resources_config);
    }
    if (!acceptable) {
      return false;
    }
  }
  return true;
}

// Classifies `cls` as an R class and asserts that it has a supported shape.
// Customized R classes (named in the config) get lighter checks, because they
// may carry extra, inconsequential getters. Only their <clinit>, if present,
// is validated. Toolchain-generated R inner classes must satisfy
// valid_r_class_structure.
bool RClassReader::is_r_class(const DexClass* cls) const {
  if (is_customized_resource_class(cls, m_global_resources_config)) {
    auto* clinit = cls->get_clinit();
    always_assert_log(
        clinit == nullptr ||
            valid_r_class_clinit(clinit, m_global_resources_config),
        "<clinit> unsupported of custom R class %s", SHOW(cls));
    return true;
  }
  if (!is_non_customized_r_class(cls)) {
    return false;
  }
  always_assert_log(
      valid_r_class_structure(cls),
      "%s is a R inner class but does not have the required structure.",
      SHOW(cls));
  return true;
}

bool RClassReader::is_r_class(const DexFieldRef* field_ref) const {
  auto* field_cls = type_class(field_ref->get_class());
  if (field_cls == nullptr) {
    return false;
  }
  return is_r_class(field_cls);
}

// Short alias for the constant propagation namespace used below.
namespace cp = constant_propagation;
// Combined instruction analyzer that analyze_clinit uses to evaluate an R
// class <clinit>. Going by the component names, it models:
//  * the class's static fields,
//  * locally allocated arrays,
//  * heap escapes,
//  * primitive constants.
// See ConstantPropagationAnalysis.h for the exact semantics of each.
using ArrayAnalyzer = InstructionAnalyzerCombiner<cp::ClinitFieldAnalyzer,
                                                  cp::LocalArrayAnalyzer,
                                                  cp::HeapEscapeAnalyzer,
                                                  cp::PrimitiveAnalyzer>;

// Runs intraprocedural constant propagation over the <clinit> of `cls`. It
// returns, for every static array field of `cls` that the <clinit> stores
// via sput-object, the int contents that field ends up holding.
//
// Two store shapes are supported:
//  * Arrays built locally. Their contents are read from the abstract heap at
//    the exit block.
//  * Arrays loaded from another field with sget-object. Their contents are
//    copied from `known_field_values`, which must already contain that source
//    field. If the source field belongs to an external class, the field is
//    skipped instead.
// Any other shape, or a non-int[] field, trips an assertion. The <clinit>
// must have its CFG built. Returns an empty map when there is no <clinit>
// code.
FieldArrayValues RClassReader::analyze_clinit(
    DexClass* cls, const FieldArrayValues& known_field_values) const {
  FieldArrayValues values;
  auto* clinit = cls->get_clinit();
  if (clinit == nullptr || clinit->get_code() == nullptr) {
    return values;
  }
  always_assert(clinit->get_code()->cfg_built());
  auto& cfg = clinit->get_code()->cfg();
  cfg.calculate_exit_block();

  cp::intraprocedural::FixpointIterator intra_cp(
      /* cp_state */ nullptr, cfg,
      ArrayAnalyzer(cls->get_type(), nullptr, nullptr, nullptr));
  intra_cp.run(ConstantEnvironment());

  // Use-def chains are only needed when an array is reused from another
  // field, so they are built lazily.
  Lazy<live_range::UseDefChains> udchain(
      [&]() { return live_range::Chains(cfg).get_use_def_chains(); });

  UnorderedSet<DexField*> locally_built_fields;
  for (auto* block : cfg.blocks()) {
    auto env = intra_cp.get_entry_state_at(block);
    auto last_insn = block->get_last_insn();
    for (auto& mie : InstructionIterable(block)) {
      auto* insn = mie.insn;
      if (insn->opcode() == OPCODE_SPUT_OBJECT &&
          insn->get_field()->get_class() == clinit->get_class()) {
        // NOTE: this entire job may be best performed as interprocedural.
        // Some day.
        auto* field_type = insn->get_field()->get_type();
        auto* field_def = insn->get_field()->as_def();
        always_assert(type::is_array(field_type));
        const DexType* element_type =
            type::get_array_component_type(field_type);
        always_assert_log(type::is_int(element_type),
                          "R clinit array are expected to be [I. Got %s",
                          SHOW(field_type));

        // `env` is the state just before `insn`, i.e. it describes the array
        // being stored.
        auto array_domain =
            env.get_pointee<ConstantValueArrayDomain>(insn->src(0));
        if (!array_domain.is_value()) {
          // Assert that this comes from a different array that is already
          // known. If so, this "reuse" does not need to be tracked specially.
          // (Standard C++ brace-initialization; the previous `(T){...}` form
          // was a GNU compound-literal extension.)
          const auto& defs = udchain->at(live_range::Use{insn, 0});
          always_assert_log(defs.size() == 1,
                            "Expecting single def flowing into field %s in %s ",
                            SHOW(insn->get_field()), SHOW(cfg));
          IRInstruction* def = *defs.begin();
          if (opcode::is_move_result_pseudo_object(def->opcode())) {
            auto it =
                cfg.primary_instruction_of_move_result(cfg.find_insn(def));
            def = it->insn;
          }
          always_assert_log(def->opcode() == OPCODE_SGET_OBJECT,
                            "Unsupported array definition at %s in %s",
                            SHOW(def), SHOW(cfg));
          auto* source_field = def->get_field();
          if (!is_external_ref(source_field)) {
            always_assert_log(known_field_values.count(source_field) > 0,
                              "Field %s was not analyzed",
                              SHOW(source_field));
            // Give callers a consistent view of the used IDs for fields that
            // reference other fields. Copy the known contents directly into
            // the result (a single copy).
            values.emplace(field_def, known_field_values.at(source_field));
          }
        } else {
          locally_built_fields.emplace(field_def);
        }
      }
      intra_cp.analyze_instruction(insn, &env, insn == last_insn->insn);
    }
  }

  // Read the final contents of every locally built array at the exit block.
  auto env = intra_cp.get_exit_state_at(cfg.exit_block());
  for (auto* f : UnorderedIterable(locally_built_fields)) {
    auto field_value = env.get(f);
    auto heap_ptr = field_value.maybe_get<AbstractHeapPointer>();
    always_assert_log(heap_ptr && heap_ptr->is_value(),
                      "Could not determine field value %s", SHOW(f));
    auto array_domain = env.get_pointee<ConstantValueArrayDomain>(*heap_ptr);
    always_assert(array_domain.is_value());
    std::vector<uint32_t> array_content;
    auto len = array_domain.length();
    array_content.reserve(len);
    for (size_t i = 0; i < len; ++i) {
      auto element = array_domain.get(static_cast<uint32_t>(i));
      auto value = element.maybe_get<SignedConstantDomain>();
      always_assert_log(value,
                        "%s is not in the SignedConstantDomain, "
                        "stored at %zu in %s:\n%s",
                        SHOW(element), i, SHOW(clinit), SHOW(cfg));
      auto cst = value->get_constant();
      always_assert_log(cst, "%s is not a constant", SHOW(*value));
      array_content.emplace_back(static_cast<uint32_t>(*cst));
    }
    values.emplace(f, std::move(array_content));
  }
  return values;
}

void RClassReader::ordered_r_class_iteration(
    const Scope& scope, const std::function<void(DexClass*)>& callback) const {
  Scope apply_scope;
  for (auto* cls : scope) {
    if (is_r_class(cls)) {
      apply_scope.emplace_back(cls);
    }
  }
  size_t clinit_cycles = 0;
  Scope ordered_scope =
      init_deps::reverse_tsort_by_clinit_deps(apply_scope, clinit_cycles);
  always_assert_log(clinit_cycles == 0, "Found %zu clinit cycles",
                    clinit_cycles);

  for (auto* cls : ordered_scope) {
    callback(cls);
  }
}

void RClassReader::extract_resource_ids_from_static_arrays(
    const Scope& scope,
    const UnorderedSet<DexField*>& array_fields,
    UnorderedSet<uint32_t>* out_values) const {
  Timer t("extract_resource_ids_from_static_arrays");
  FieldArrayValues field_values;
  ordered_r_class_iteration(scope, [&](DexClass* cls) {
    auto class_state = analyze_clinit(cls, field_values);
    field_values.insert(class_state.begin(), class_state.end());
  });
  for (auto&& [f, vec] : field_values) {
    auto* field_def = f->as_def();
    if (field_def != nullptr && array_fields.count(field_def) > 0) {
      out_values->insert(vec.begin(), vec.end());
    }
  }
}

// Rewrites scalar resource IDs according to `old_to_remapped_ids`, in two
// places:
//  * int static fields of R classes. Encoded values above
//    PACKAGE_RESID_START that have an entry in the map are replaced. Values
//    without an entry, and fields without an encoded value, are left
//    untouched.
//  * Every IOPCODE_R_CONST literal in `stores`. Each must be a non-negative
//    value that fits in 32 bits and has an entry in the map; otherwise an
//    assertion fires.
void RClassWriter::remap_resource_class_scalars(
    DexStoresVector& stores,
    const std::map<uint32_t, uint32_t>& old_to_remapped_ids) const {
  auto scope = build_class_scope(stores);
  RClassReader r_class_reader(m_global_resources_config);
  for (auto* clazz : scope) {
    if (!r_class_reader.is_r_class(clazz)) {
      continue;
    }
    for (auto* field : clazz->get_sfields()) {
      if (!type::is_int(field->get_type())) {
        continue;
      }
      auto* static_value = field->get_static_value();
      if (static_value == nullptr) {
        // No encoded value, hence no resource ID to remap.
        continue;
      }
      auto encoded_val = static_value->value();
      always_assert(encoded_val <= std::numeric_limits<int32_t>::max());
      auto encoded_int = static_cast<uint32_t>(encoded_val);
      if (encoded_int <= PACKAGE_RESID_START) {
        continue;
      }
      auto remapped = old_to_remapped_ids.find(encoded_int);
      if (remapped != old_to_remapped_ids.end()) {
        static_value->value(remapped->second);
      }
    }
  }
  walk::parallel::opcodes(scope, [&](DexMethod*, IRInstruction* insn) {
    if (insn->opcode() != IOPCODE_R_CONST) {
      return;
    }
    int64_t old = insn->get_literal();
    // Validate the range *before* looking the ID up. The map is keyed by
    // uint32_t, so looking up the raw int64_t would silently truncate it.
    always_assert_log(old >= 0, "Resource ID %llx must be positive",
                      static_cast<long long>(old));
    always_assert_log(old <= std::numeric_limits<uint32_t>::max(),
                      "Resource ID %llx needs to fit in 32 bits",
                      static_cast<long long>(old));
    auto remapped = old_to_remapped_ids.find(static_cast<uint32_t>(old));
    always_assert_log(remapped != old_to_remapped_ids.end(),
                      "Encountered resource ID %llx which cannot be "
                      "remapped",
                      static_cast<long long>(old));
    insn->set_literal(remapped->second);
  });
}

namespace {
// Appends the remapped form of `original_values` to `new_values` and reports
// whether it differs from the input.
//  * Values at or below PACKAGE_RESID_START are copied unchanged.
//  * Larger values (resource IDs) with an entry in `old_to_remapped_ids` are
//    replaced by that entry.
//  * Larger values without an entry are dropped, or replaced by 0 when
//    `zero_out_values` is set.
bool remap_array(const std::vector<uint32_t>& original_values,
                 const std::map<uint32_t, uint32_t>& old_to_remapped_ids,
                 const bool zero_out_values,
                 std::vector<uint32_t>* new_values) {
  bool differs = false;
  for (const uint32_t value : original_values) {
    if (value <= PACKAGE_RESID_START) {
      new_values->emplace_back(value);
      continue;
    }
    const auto found = old_to_remapped_ids.find(value);
    if (found != old_to_remapped_ids.end()) {
      new_values->emplace_back(found->second);
      differs = differs || found->second != value;
      continue;
    }
    // The ID has no mapping. Callers request zeroing for styleable arrays,
    // where offsets elsewhere point into the array; deleting the entry would
    // shift those positions, so a 0 placeholder is kept instead.
    differs = true;
    if (zero_out_values) {
      new_values->emplace_back(0);
    }
  }
  return differs;
}

// Runs ObjectSensitiveDce over `scope` to prune instructions that became dead
// (e.g. the original array fills left behind by the rewriting). No method is
// treated as pure, and no escape or side-effect summaries are supplied, so
// OSDCE has to stay conservative about code outside `scope`.
void perform_dce(Scope& scope) {
  auto override_graph = method_override_graph::build_graph(scope);
  init_classes::InitClassesWithSideEffects side_effect_inits(
      scope, true, override_graph.get());
  UnorderedSet<DexMethodRef*> no_pure_methods;
  local_pointers::SummaryMap no_escape_summaries;
  side_effects::SummaryMap no_effect_summaries;
  ObjectSensitiveDce osdce(scope,
                           &side_effect_inits,
                           no_pure_methods,
                           *override_graph,
                           0, /* should not be encountering virtual calls here */
                           &no_escape_summaries,
                           &no_effect_summaries);
  osdce.dce();
  TRACE(OPTRES, 2, "Pruned %zu instruction(s); R class scope size = %zu",
        osdce.get_stats().removed_instructions, scope.size());
}
} // namespace

size_t RClassWriter::remap_resource_class_clinit(
    const DexClass* cls,
    const FieldArrayValues& field_values,
    const std::map<uint32_t, uint32_t>& old_to_remapped_ids) const {
  IRCode* ir_code = cls->get_clinit()->get_code();
  always_assert(ir_code->cfg_built());

  // For styleable, we avoid actually deleting entries since there are offsets
  // that will point to the wrong positions in the array. Instead, zero out the
  // values.
  bool zero_out_values = is_styleable(cls);
  FieldArrayValues pending_new_values;
  for (auto&& [f, vec] : field_values) {
    std::vector<uint32_t> new_values;
    if (remap_array(vec, old_to_remapped_ids, zero_out_values, &new_values)) {
      pending_new_values.emplace(f, new_values);
    }
  }

  if (pending_new_values.empty()) {
    return 0;
  }
  auto& cfg = ir_code->cfg();
  cfg::CFGMutation mutation(cfg);

  // Regenerate the filling of the changed fields, leaving the old code behind.
  std::map<int32_t, reg_t> values_to_reg;
  auto get_register_for_value = [&](int32_t value) {
    auto search = values_to_reg.find(value);
    if (search != values_to_reg.end()) {
      return search->second;
    }
    auto reg = cfg.allocate_temp();
    values_to_reg.emplace(value, reg);
    return reg;
  };
  auto iterable = cfg::InstructionIterable(cfg);
  for (auto it = iterable.begin(); it != iterable.end(); ++it) {
    auto* insn = it->insn;
    if (insn->opcode() == OPCODE_SPUT_OBJECT &&
        pending_new_values.count(insn->get_field()) > 0) {
      auto new_values = pending_new_values.at(insn->get_field());
      // Generate new instructions to create array of new size, fill it, move it
      // to field. Something like the following, except CONST instructions will
      // be slapped into the beginning of entry block:
      // OPCODE: CONST v0, 10
      // OPCODE: NEW_ARRAY v0, [I
      // OPCODE: IOPCODE_MOVE_RESULT_PSEUDO_OBJECT v0
      // OPCODE: FILL_ARRAY_DATA v0, <data>
      //   fill-array-data-payload { [10 x 4] { ... yada yada ... } }
      // OPCODE: SPUT_OBJECT v0, Lcom/redextest/R;.six:[I
      auto size_reg =
          get_register_for_value(static_cast<int32_t>(new_values.size()));
      auto array_reg = cfg.allocate_temp();

      auto* new_array = new IRInstruction(OPCODE_NEW_ARRAY);
      new_array->set_src(0, size_reg);
      new_array->set_type(insn->get_field()->get_type());
      auto* move_result_pseudo =
          new IRInstruction(IOPCODE_MOVE_RESULT_PSEUDO_OBJECT);
      move_result_pseudo->set_dest(array_reg);
      if (!new_values.empty()) {
        auto* fill_array_data = new IRInstruction(OPCODE_FILL_ARRAY_DATA);
        fill_array_data->set_src(0, array_reg);
        auto op_data = encode_fill_array_data_payload(new_values);
        fill_array_data->set_data(std::move(op_data));

        mutation.insert_before(
            cfg.find_insn(insn),
            {new_array, move_result_pseudo, fill_array_data});
      } else {
        mutation.insert_before(cfg.find_insn(insn),
                               {new_array, move_result_pseudo});
      }
      insn->set_src(0, array_reg);
    }
  }

  // Ensure all constants are at beginning of entry block and available to succs
  std::vector<IRInstruction*> consts;
  for (auto&& [lit, reg] : values_to_reg) {
    auto* insn = new IRInstruction(OPCODE_CONST);
    insn->set_dest(reg);
    insn->set_literal(lit);
    consts.emplace_back(insn);
  }
  auto i2 = cfg::InstructionIterable(cfg);
  mutation.insert_before(i2.begin(), consts);
  mutation.flush();
  return pending_new_values.size();
}

// Remaps the resource IDs held in the static int[] fields of all R classes in
// `stores`, in three steps:
//  1. Analyze every R class <clinit> in dependency order, so that arrays
//     copied from other R class fields resolve.
//  2. Rewrite the <clinit> of each class that has array fields.
//  3. If cleanup_r_class_rewriting is enabled, run OSDCE over the R classes to
//     remove the now-dead original array fills. Otherwise (e.g. instrumented
//     builds, whose <clinit> callees are not in the cleanup scope) that code is
//     left behind.
void RClassWriter::remap_resource_class_arrays(
    DexStoresVector& stores,
    const std::map<uint32_t, uint32_t>& old_to_remapped_ids) const {
  Timer timer("remap_resource_class_arrays");
  // Contents of every analyzed R class field, taken before any rewriting.
  FieldArrayValues all_field_values;
  RClassReader reader(m_global_resources_config);

  std::vector<std::pair<DexClass*, FieldArrayValues>> per_class_values;
  auto scope = build_class_scope(stores);

  reader.ordered_r_class_iteration(scope, [&](DexClass* cls) {
    TRACE(OPTRES, 3, "remap_resource_class_arrays, considering class %s",
          SHOW(cls));
    auto values = reader.analyze_clinit(cls, all_field_values);
    all_field_values.insert(values.begin(), values.end());
    per_class_values.emplace_back(cls, std::move(values));
  });

  Scope cleanup_scope;
  cleanup_scope.reserve(per_class_values.size());
  for (auto& [cls, field_values] : per_class_values) {
    cleanup_scope.emplace_back(cls);
    if (field_values.empty()) {
      continue;
    }
    auto changes =
        remap_resource_class_clinit(cls, field_values, old_to_remapped_ids);
    TRACE(OPTRES, 2, "Updated %zu field(s) in %s clinit", changes, SHOW(cls));
  }

  if (!m_global_resources_config.cleanup_r_class_rewriting) {
    TRACE(OPTRES, 1, "Skipping cleanup of old array filling instructions!!");
    return;
  }
  TRACE(OPTRES, 2,
        "Cleaning up old array filling instructions on %zu class(es)",
        cleanup_scope.size());
  perform_dce(cleanup_scope);
}
} // namespace resources
